import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D

# ------------------------
# Configuration
# ------------------------
num_nodes = 8
slots_per_node = 4
sample_rate = 200          # samples per second
lfo_rate = 0.5             # Hz for analog oscillation
amplitude_scale = 0.1
noise_scale = 0.02
band_freqs = [100, 200, 300]
num_bands = len(band_freqs)  # derived so it can never disagree with band_freqs
# NOTE(review): the animation clock steps by dt = 1/sample_rate, so the band
# carriers are sampled at t ~= k/200 s. 100 Hz is exactly Nyquist and 200/300 Hz
# lie above it, so sin(2*pi*f*t + phase) reduces to +/-sin(phase) at every
# sample: the three bands alias onto each other instead of adding distinct
# tones. Consider sample_rate > 2*max(band_freqs) or lower band frequencies.

# Environmental AM/FM signals
am_freq = 50               # Hz, amplitude-modulated carrier
fm_freq = 70               # Hz, centre frequency of the FM carrier
fm_dev = 5                 # Hz of frequency shift per unit of composite

# Morph parameter (0 = polar, 1 = cartesian)
morph = 0.0

# ------------------------
# Initialize HDGL lattice
# ------------------------
# One row per node, one column per slot; both arrays are mutated in place by
# evolve_lattice().
node_lattices = np.zeros((num_nodes, slots_per_node))
node_phases = np.zeros((num_nodes, slots_per_node))

# ------------------------
# Lattice evolution
# ------------------------
def evolve_lattice(lattice, phases, t):
    """Advance the lattice one step and return the mixed output sample.

    Every (node, slot) cell receives the same increment ``delta`` in both
    ``lattice`` and ``phases``. The increment is a slow LFO term, phase-offset
    by ``node + slot`` radians, plus Gaussian noise drawn from ``np.random``.
    The updated cell values then weight a bank of sinusoidal carriers at
    ``band_freqs``, each carrier using the cell's phase. The mean over all
    cells and bands is the ``composite``. The composite modulates an AM
    carrier (``am_freq``) and an FM carrier (``fm_freq`` shifted by
    ``fm_dev * composite``).

    Args:
        lattice: 2-D float array of per-cell amplitudes; updated in place.
        phases: Float array with the same shape as ``lattice``; updated in place.
        t: Time in seconds at which the carriers are evaluated.

    Returns:
        ``0.5*composite + 0.25*am + 0.25*fm`` as a NumPy float. An empty
        lattice (or an empty ``band_freqs``) contributes ``composite = 0``.
    """
    # Index grids reproduce the per-cell offset (node_idx + slot_idx) without a
    # Python double loop.
    node_idx, slot_idx = np.indices(lattice.shape)
    lfo = np.sin(2*np.pi*lfo_rate*t + node_idx + slot_idx)
    # One normal draw per cell in row-major order. This is the same sequence
    # that per-cell scalar np.random.randn() calls would consume.
    noise = np.random.randn(*lattice.shape)
    delta = amplitude_scale*lfo + noise_scale*noise
    lattice += delta
    phases += delta

    n_terms = lattice.size * len(band_freqs)
    if n_terms == 0:
        # Nothing to average; avoid dividing by zero.
        composite = 0.0
    else:
        # All band carriers at once: shape (bands, nodes, slots).
        freqs = np.asarray(band_freqs, dtype=float).reshape(-1, 1, 1)
        carriers = np.sin(2*np.pi*freqs*t + phases)
        composite = float(np.sum(lattice * carriers)) / n_terms

    # AM and FM carriers
    am_signal = np.sin(2*np.pi*am_freq*t) * (1 + 0.5*composite)
    # NOTE(review): this scales the *instantaneous* frequency by absolute t, so
    # the phase offset 2*pi*fm_dev*composite*t grows with t. True FM integrates
    # the frequency over time (a running phase accumulator). It is kept as-is
    # because a fix needs state carried between calls.
    fm_signal = np.sin(2*np.pi*(fm_freq + fm_dev*composite)*t)

    return 0.5*composite + 0.25*am_signal + 0.25*fm_signal

# ------------------------
# Setup 3D plotting
# ------------------------
plt.style.use('dark_background')
fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim(-1,1)
ax.set_ylim(-1,1)
ax.set_zlim(-1,1)

# Pin the colour range to the z-axis limits. With an empty initial `c` the norm
# has nothing to autoscale from. The first set_array() would then fix
# vmin == vmax at that frame's value, and every later frame would map to a
# single colour.
points = ax.scatter([], [], [], c=[], cmap='viridis', s=50, vmin=-1, vmax=1)

# NOTE(review): not read anywhere in the visible code; presumably reserved for
# recording the signal history.
time_data = []
signal_data = []

t_global = 0.0          # simulation clock in seconds, advanced once per frame
dt = 1.0 / sample_rate  # seconds per frame / signal sample

# ------------------------
# Update function for animation
# ------------------------
def update(frame):
    """Advance the simulation one step and redraw the scatter.

    This is the FuncAnimation callback. ``frame`` is supplied by FuncAnimation
    and is unused; the clock lives in the module-level ``t_global``. Each call
    evaluates the lattice, then places the (slot, node) grid of points. The
    layout is a blend, weighted by ``morph``, of a polar ring layout and a
    regular Cartesian grid. Finally it advances ``t_global`` by ``dt`` and
    nudges ``morph`` towards 1.

    Returns:
        A 1-tuple containing the updated scatter artist.
    """
    global t_global, morph
    composite = evolve_lattice(node_lattices, node_phases, t_global)

    # Layout grids have shape (slots_per_node, num_nodes): one column per node,
    # one row per slot.
    theta = np.linspace(0, 2*np.pi, num_nodes, endpoint=False)
    r = np.linspace(0.2, 1.0, slots_per_node)
    THETA, R = np.meshgrid(theta, r)

    # Polar layout: each node on its own spoke, slots at increasing radius.
    x_polar = R * np.cos(THETA)
    y_polar = R * np.sin(THETA)
    # Cartesian layout: nodes along x, slots along y, spread across the axis
    # limits, so every point keeps a distinct position at morph == 1.
    x_cart, y_cart = np.meshgrid(np.linspace(-1.0, 1.0, num_nodes),
                                 np.linspace(-1.0, 1.0, slots_per_node))

    # Morph polar -> cartesian
    X = (1 - morph) * x_polar + morph * x_cart
    Y = (1 - morph) * y_polar + morph * y_cart

    # Broadcast the height to the full grid so Z has one value per point.
    # This works whether composite is a scalar (what evolve_lattice returns) or
    # a length-num_nodes vector. np.tile(scalar, (slots_per_node, 1)) yields
    # only slots_per_node values.
    Z = np.broadcast_to(np.asarray(composite, dtype=float), THETA.shape)

    # Flatten for scatter (flatten copies, so Zf is writable and contiguous).
    Xf = X.flatten()
    Yf = Y.flatten()
    Zf = Z.flatten()

    # matplotlib has no public setter for 3-D scatter offsets; _offsets3d is
    # what the collection projects at draw time.
    points._offsets3d = (Xf, Yf, Zf)
    points.set_array(Zf)  # colour each point by its height

    # Advance time
    t_global += dt
    # Slowly morph from polar to cartesian (reaches 1.0 after ~500 frames).
    morph = min(1.0, morph + 0.002)
    return points,

# frames is unbounded (no `frames` argument), so frame data cannot usefully be
# cached. Without cache_frame_data=False, newer matplotlib warns and disables the
# cache itself. `ani` must stay referenced or the animation is garbage-collected.
ani = FuncAnimation(fig, update, interval=dt*1000, blit=False,
                    cache_frame_data=False)
plt.show()
